__author__ = 'JackSong'
import math
import base64
import logging
import datetime
import traceback
import tornado.websocket
from model.pool import Pool
from tornado.gen import coroutine, Return
from tornado.options import options
from tornado.web import asynchronous
from utils.cephRados import CephConnection
from utils.data import ClientData
from baseHandler import BaseRequestHandler, BaseWebSocketHandler
from tornado.websocket import websocket_connect, WebSocketClosedError
from tornado.concurrent import run_on_executor
from tornado.ioloop import IOLoop, PeriodicCallback
from middleware.authMiddleware import PermissionCheckMiddleware, CheckPoolMiddleware
from sqlalchemy.orm.exc import NoResultFound
from tornado.web import HTTPError

logger = logging.getLogger()


class From(object):
    """Attribute-style proxy over a handler's request arguments.

    ``From(handler).foo`` evaluates to ``handler.get_argument('foo', None)``,
    so a missing argument yields ``None`` instead of raising.
    """

    def __init__(self, obj):
        self.obj = obj

    def __getattr__(self, item):
        # Fall back to None rather than tornado's MissingArgumentError.
        return self.obj.get_argument(item, None)


class LargeFilesHandler(BaseRequestHandler):
    """Uploads (possibly resumable/chunked) files into a ceph pool."""

    @asynchronous
    @coroutine
    def post(self):
        """Write each uploaded 'file' part at the requested byte offset.

        Request arguments: ``pool`` (target pool name) and optionally
        ``offset`` (byte offset for chunked uploads, defaults to 0).
        """
        f = From(self)
        file_metas = self.request.files.get('file')
        if not file_metas:
            # Guard against a missing multipart 'file' field (mirrors
            # FileHandler) instead of raising TypeError iterating None.
            self.write('not find file or pool')
            self.finish()
            return
        for meta in file_metas:
            filename = meta['filename']
            offset = int(f.offset) if f.offset else 0
            rest = yield self.ceph.async_write_full(f.pool, filename, meta.get('body'),
                                                    self._id, offset=offset)
            if not isinstance(rest, str):
                if rest.is_safe():
                    self.write('safe write!')
            else:
                self.write(rest)
        # @asynchronous suppresses the coroutine's auto-finish; without an
        # explicit finish() the HTTP request never completes.
        self.finish()


class FileHandler(BaseRequestHandler):
    """CRUD-style access to objects inside a ceph pool."""

    @coroutine
    def get(self):
        """Return stat info for one object, or the pool's object listing
        when no ``object`` argument is given."""
        poolname = self.get_argument('poolname')
        obj = self.get_argument('object', None)
        if obj:
            stat = yield self.ceph.object_info(poolname, obj)
            self.write(stat)
        else:
            obj_list = yield self.ceph.obj_list(poolname)
            self.write(obj_list)

    @asynchronous
    @coroutine
    def post(self):
        """Store every uploaded 'file' multipart part into ``poolname``."""
        file_metas = self.request.files.get('file')
        pool = self.get_argument('poolname')
        if file_metas and pool:
            for meta in file_metas:
                filename = meta['filename']
                rest = yield self.ceph.async_write_full(pool, filename, meta.get('body'), self._id)
                if not isinstance(rest, str):
                    if rest.is_safe():
                        self.write('safe write!')
                        yield self.flush()
                else:
                    self.write(rest)
        else:
            self.write('not find file or pool')
        # @asynchronous disables the coroutine's auto-finish; the request
        # would hang without this explicit finish().
        self.finish()

    @coroutine
    def delete(self):
        """Delete a single object from a pool."""
        pool = self.get_argument('poolname')
        file_name = self.get_argument('filename')
        res = yield self.ceph.delete_pool_object(pool, file_name)
        self.write(res)


class DownloadHandler(BaseRequestHandler):
    """Streams an object out of a ceph pool to the client in chunks."""

    middleware = (CheckPoolMiddleware(),)

    @property
    @coroutine
    def middlewares(self):
        """Resolve the middleware chain from the target pool's policy.

        Pools with ``policy == 2`` are public (empty chain); an unknown pool
        gets a middleware that raises 404; anything else keeps the default
        chain from the base class.
        """
        poolname = self.get_argument('poolname')

        def _lambda(db):
            try:
                pool = db.query(Pool).filter_by(name=poolname).one()
                self.pool = pool
                if pool.policy == 2:
                    return list()
                else:
                    return super(DownloadHandler, self).middlewares
            except NoResultFound:
                class HttpError(object):
                    def process_request(self, request):
                        raise HTTPError(404, 'not find pool')

                return (HttpError(),)

        res = yield self.asyncdb(_lambda)
        raise Return(res)

    @asynchronous
    @coroutine
    def get(self):
        """Send the object as an attachment (served inline only for HTML
        files when the pool allows it)."""
        pool = self.get_argument('poolname')
        filename = self.get_argument('filename')
        if not self.pool.html or filename[-4:] not in ('html', '.htm',):
            self.set_header('Content-Type', 'application/octet-stream')
            self.set_header('Content-Disposition', 'attachment; filename=' + filename)
        stat = yield self.ceph.object_info(pool, filename)
        if stat.get('status') not in ('OK',):
            raise HTTPError(404)
        size, timestamp = stat.get('output')
        self.set_header("Content-Length", size)
        pagesize = 1024 * 1024 * 4  # stream in 4 MiB chunks
        offset = 0
        while offset < size:
            # Clamp the final chunk so we never request bytes past EOF.
            chunk = min(pagesize, size - offset)
            data = yield self.ceph.async_read_full(pool, filename, self._id, chunk, offset)
            offset += chunk
            self.write(data)
            yield self.flush()
        # @asynchronous suppresses auto-finish; close the request explicitly
        # once the whole object has been flushed.
        self.finish()


class WSObjectUploadHandlerWeb(BaseWebSocketHandler):
    """WebSocket endpoint that receives an object upload in base64 chunks.

    Per-connection state (ceph connection, pool/file name, current write
    offset) lives in the class-level ``Clients`` dict keyed by ``id(self)``.
    """

    Clients = dict()

    def open(self):
        self.put_client()

    def put_client(self, data=None):
        """Store (or refresh) this connection's state dict."""
        if data:
            self.Clients[self._id()] = data
        else:
            self.Clients[self._id()] = {"ceph": CephConnection()}

    def get_client(self):
        return self.Clients.get(self._id(), None)

    @coroutine
    def on_message(self, message):
        """Handle either the init frame (pool/filename/offset) or a chunk."""
        client = self.get_client()
        if self.is_init(message):
            data = ClientData(message)
            client.update({'poolname': data.poolname, 'filename': data.filename, 'offset': data.offset})
            self.put_client(client)
            self.write_message('success')
        else:
            message = message.encode('utf-8')
            salt = 'data:;base64,'
            if salt in message:
                # Strip the data-URL prefix that browser FileReader output carries.
                split = message.split(salt)
                message = split[1]
            message = base64.b64decode(message)
            self.upload_object(message)

    @coroutine
    def upload_object(self, message):
        """Write one decoded chunk at the current offset, then ack the client."""
        state = True
        client = self.get_client()
        while state:
            offset = int(client.get('offset'))
            filename = client.get('filename')
            try:
                rest = yield client.get('ceph').async_write_full(pool=client.get('poolname'), name=filename,
                                                                 obj_id=self._id,
                                                                 data=message, offset=offset)
                if rest.is_complete():
                    offset += len(message)
                    client.update({'offset': offset})
                    try:
                        self.write_message('success')
                    except tornado.websocket.WebSocketError:
                        self.on_close()
                    state = False
            except Exception as e:
                logger.info(e)
                traceback.print_exc()
                break

    def on_close(self):
        # Remove this connection's entry from the class-level map. The old
        # code only deleted a *local* reference, so every closed socket leaked
        # its Clients entry (and CephConnection) forever.
        self.Clients.pop(self._id(), None)

    def _id(self):
        # Connections are keyed by object identity.
        return id(self)

    def is_init(self, data):
        """Heuristic: only the init frame contains the text 'filename'."""
        if 'filename' in data:
            return True
        return False

    def check_origin(self, origin):
        # Accept cross-origin websocket connections.
        return True


class WSRBDHandlerWeb(BaseWebSocketHandler):
    """WebSocket endpoint that receives an RBD image upload in base64 chunks."""

    Clients = dict()

    def open(self):
        self.put_client()

    def put_client(self, data=None):
        """Store (or refresh) this connection's state dict."""
        if data:
            self.Clients[self._id()] = data
        else:
            self.Clients[self._id()] = {"ceph": CephConnection()}

    def get_client(self):
        return self.Clients.get(self._id(), None)

    @coroutine
    def on_message(self, message):
        """Handle the init frame (creates the rbd image) or a data chunk."""
        client = self.get_client()
        if self.is_init(message):
            data = ClientData(message)
            client.update(
                {'poolname': data.poolname, 'filename': data.filename, 'offset': data.offset, 'size': data.size})
            self.put_client(client)
            self.write_message('success')
            yield self.ceph.create_rbd(data.poolname, data.filename, data.size)
        else:
            message = message.encode('utf-8')
            salt = 'data:;base64,'
            if salt in message:
                # Strip the browser data-URL prefix before decoding.
                split = message.split(salt)
                message = split[1]
            message = base64.b64decode(message)
            yield self.upload_object(message)

    @coroutine
    def upload_object(self, message):
        """Write one chunk into the rbd image on the executor thread pool."""
        state = True
        client = self.get_client()
        while state:
            offset = int(client.get('offset'))
            filename = client.get('filename')
            try:
                rest = yield self.executor.submit(client.get('ceph').write_rbd, *(client.get('poolname'), filename,
                                                                                  message, offset))
                offset += rest
                client.update({'offset': offset})
                try:
                    self.write_message('success')
                except tornado.websocket.WebSocketError:
                    self.on_close()
                state = False
            except Exception as e:
                logger.info(e)
                traceback.print_exc()
                break

    def on_close(self):
        # The old code dereferenced get_client() unconditionally (raising
        # AttributeError if the entry was already gone) and then only deleted
        # a local name, leaking the Clients entry. Pop the entry instead.
        self.Clients.pop(self._id(), None)

    def is_init(self, data):
        """Heuristic: only the init frame contains the text 'filename'."""
        if 'filename' in data:
            return True
        return False


class SocketConnection(object):
    """Chat-style TCP connection: every received line is broadcast to all
    currently connected clients."""

    CLIENTS = set()  # all live connections; class-level so it's shared

    def __init__(self, stream, address):
        self.CLIENTS.add(self)
        self._stream = stream
        self._address = address
        self._stream.set_close_callback(self.on_close)
        self.read_message()

    def read_message(self):
        # Read one newline-terminated message, then fan it out and re-arm.
        self._stream.read_until('\n', self.broadcast_messages)

    def broadcast_messages(self, data):
        """Send the received line to every client, then wait for the next."""
        for conn in self.CLIENTS:
            conn.send_message(data)
        self.read_message()

    def send_message(self, data):
        self._stream.write(data)

    def on_close(self):
        # discard() instead of remove(): the close callback can fire after the
        # connection has already been dropped, and remove() would raise
        # KeyError in that case.
        self.CLIENTS.discard(self)


class SyncObjectHandlerWeb(BaseWebSocketHandler):
    """Copy an object from a source pool to a destination pool via a
    websocket upload endpoint, pushing progress updates to the browser.

    Transfer state is persisted in a cache entry keyed by
    ``"<src_pool>-<src_obj>-<dst_pool>-<dst_obj>"`` so a reconnecting client
    can attach to an in-flight copy instead of starting a new one, and an
    interrupted copy can be resumed in the background.
    """

    def check_is_init(self, message):
        # Only the initial control frame contains a 'source' field.
        if 'source' in message:
            return True
        else:
            return False

    @coroutine
    def on_message(self, message):
        """Dispatch the init frame: attach a 1 s progress timer to an
        in-flight copy, or start a new copy and close when it finishes."""
        if self.check_is_init(message):
            self.init_data = ClientData(data=message)
            # Cache key uniquely identifies this source->destination transfer.
            self.cache_key = u'{0}-{1}-{2}-{3}'.format(self.init_data.source_poolname,
                                                       self.init_data.source_obj,
                                                       self.init_data.destination_poolname,
                                                       self.init_data.destination_obj)
            block_num, _, current_num, _, _, _ = yield self.get_object_info()
            if self.check_process(block_num, current_num):
                # A copy is already running: just poll its progress every second.
                if hasattr(self, 'timers'):
                    self.timers.stop()
                self.timers = PeriodicCallback(self.get_process_status, 1000)
                self.timers.start()
            else:
                yield self.copy_object(self.write_message)
                if hasattr(self, 'close'):
                    self.close()

    def check_process(self, block_num, current_num):
        # True while a transfer is underway: some blocks done, some remaining.
        return block_num > current_num and block_num != 0 and current_num != 0

    @coroutine
    def get_process_status(self):
        """Periodic callback: push schedule/speed to the client, or stop the
        timer (and the websocket) once the transfer is no longer in flight."""
        block_num, block_size, current_num, offset, size, before_time = yield self.get_object_info()
        if self.check_process(block_num, current_num):
            try:
                self.write_message(self.get_schedule(before_time, block_num, block_size, current_num))
                if block_num == current_num:
                    # NOTE(review): unreachable — check_process() above already
                    # requires block_num > current_num. TODO confirm intent.
                    self.timers.stop()
                    if hasattr(self.ws_connection, 'close'):
                        self.ws_connection.close()
            except WebSocketClosedError as e:
                logger.info(e)
                self.timers.stop()
                # NOTE(review): '__call__' looks like it was meant to be
                # 'close' here — verify against the other branches.
                if self.ws_connection and hasattr(self.ws_connection, '__call__'):
                    self.ws_connection.close()
        else:
            self.timers.stop()
            yield self.del_cache()
            if hasattr(self.ws_connection, 'close'):
                self.ws_connection.close()

    def get_schedule(self, before_time, block_num, block_size, current_num):
        """Return a progress dict: percent complete and transfer speed.

        ``before_time`` is the duration of the last block write (presumably
        milliseconds, given the *1000 scaling — TODO confirm with set_cache).
        """
        if not isinstance(before_time, int):
            before_time = int(before_time)
        if before_time == 0:
            speed = 0
        else:
            speed = round(float(block_size) / float(before_time), 4) * 1000
        schedule = round(float(current_num) / float(block_num), 4) * 100
        return {'schedule': '%.2f%%' % (schedule),
                'speed': '{0}kb/s'.format(speed)}

    @coroutine
    def get_object_info(self):
        """Return (block_num, block_size, current_num, offset, size, before_time).

        Prefers the cached in-flight record; otherwise stats the source
        object and derives a fresh block plan from options.ceph_block_size.
        """
        before_process = yield self.cache('GET', self.cache_key)
        if before_process:
            # Cached record is the '-'-joined integer tuple written by set_cache.
            params = []
            for var in before_process.split('-'):
                params.append(int(var))
            raise Return(tuple(params))
        else:
            file_info = yield self.ceph.object_info(self.init_data.source_poolname,
                                                    self.init_data.source_obj)
            block_size = options.ceph_block_size
            if file_info.get('status', None) in ('OK',):
                size = int(file_info.get('output')[0])
                offset = 0
                if size > block_size:
                    # NOTE(review): under Python 2, size / block_size is integer
                    # division, so math.ceil is a no-op here; the '+ 1' below
                    # appears to compensate for the lost remainder block.
                    block_num = int(math.ceil(size / block_size))
                else:
                    # Object fits in one block; shrink the block to its size.
                    block_num = 1
                    block_size = size
                current_num = 0
                raise Return((block_num + 1, block_size, current_num, offset, size, 0,))
            else:
                # Source object missing or stat failed: all-zero sentinel.
                raise Return((0, 0, 0, 0, 0, 0,))

    @coroutine
    def on_close(self, flap=True):
        # Stop the progress timer only; the copy itself keeps running and is
        # re-queued in the background by copy_object's error/finally paths.
        if hasattr(self, 'timers'):
            self.timers.stop()

    @run_on_executor
    @coroutine
    def background_process(self):
        # Resume/continue the copy without a websocket client attached.
        yield self.copy_object(callback=None)

    @coroutine
    def copy_object(self, callback=None):
        """Stream the source object block-by-block to the destination's
        websocket upload endpoint (init_data.url), updating the cache after
        each block and invoking ``callback`` with progress dicts.

        On websocket closure or a TypeError, remaining work is re-queued via
        background_process so the copy survives the client going away.
        """
        conn = yield websocket_connect(self.init_data.url)
        block_num, block_size, current_num, offset, size, before_time = yield self.get_object_info()
        # Init frame for the remote upload handler (see WSObjectUploadHandlerWeb).
        yield conn.write_message(
            unicode({u"filename": self.init_data.destination_obj, u"poolname": self.init_data.destination_poolname,
                     u"size": size,
                     u"offset": offset}))
        res = yield conn.read_message()
        if res in ('success',):
            try:
                while current_num < block_num:
                    time = datetime.datetime.now()
                    data = yield self.ceph.async_read_full(self.init_data.source_poolname, self.init_data.source_obj,
                                                           self._id,block_size, offset)
                    if data:
                        data = base64.b64encode(data)
                        yield conn.write_message(data)
                        res = yield conn.read_message()
                        if res in ('success',):
                            # NOTE(review): operands look inverted — (start - now)
                            # is a negative timedelta whose .seconds is
                            # 86400 - elapsed. Probably meant (now - time).
                            seconds = (time - datetime.datetime.now()).seconds
                            yield self.set_cache(block_num, block_size, current_num, offset, seconds, size)

                            current_num += 1
                            offset = offset + block_size
                            if offset + block_size > size:
                                # Shrink the final block to the remaining bytes.
                                block_size = size - offset
                                if block_size <= 0:
                                    current_num += 1
                        # NOTE(review): 'seconds' is unbound here if the first
                        # ack was not 'success' — would raise NameError, which
                        # the except clauses below do not catch.
                        yield self.set_cache(block_num, block_size, current_num, offset, seconds, size)
                    if callback:
                        if hasattr(callback, "__call__"): yield callback(
                            self.get_schedule(seconds, block_num, block_size, current_num))
                yield self.set_cache(block_num, block_size, current_num, offset, seconds, size)
            except WebSocketClosedError as e:
                logger.info(e)
                if block_num > current_num:
                    # Client went away mid-copy: finish in the background.
                    IOLoop.instance().add_callback(self.background_process)
            except TypeError as e:
                logger.info(e)
                if callback:
                    if hasattr(callback, "__call__"): yield callback(
                        self.get_schedule(before_time, block_num, block_size, current_num))
                if block_num > current_num:
                    IOLoop.instance().add_callback(self.background_process)
            finally:
                if self.check_process(block_num, current_num):
                    # Still unfinished: keep the copy going in the background.
                    IOLoop.instance().add_callback(self.background_process)
                else:
                    yield self.del_cache()

    @coroutine
    def set_cache(self, block_num, block_size, current_num, offset, seconds, size):
        """Persist progress as 'block_num-block_size-current_num-offset-size-seconds'
        with a 60 s TTL so stale transfers expire on their own."""
        yield self.cache('SET', self.cache_key,
                         u'{0}-{1}-{2}-{3}-{4}-{5}'.format(block_num, block_size, current_num, offset,
                                                           size, seconds))
        yield self.cache("EXPIRE", self.cache_key, 60)

    @coroutine
    def del_cache(self):
        # Remove the progress record once the transfer is finished or dead.
        yield self.cache('DEL', self.cache_key, )
